{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 3: Implementing CycleGAN\n",
    "\n",
    "**Author** - [Avik Pal](https://avik-pal.github.io/)\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/torchgan/torchgan/blob/master/tutorials/Tutorial%203.%20CycleGAN.ipynb)\n",
    "\n",
    "This is an **Intermediate Tutorial**, and we recommend first going through the previous tutorials, which cover the basics of torchgan. In this tutorial we implement the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Zhu et al.](https://arxiv.org/abs/1703.10593).\n",
    "\n",
    "Since this is a fairly advanced model, we will have to subclass most of the features provided by torchgan. The tutorial can be broadly categorized into 3 major components:\n",
    "\n",
    "1. The Generator and the Discriminator Models\n",
    "2. The Loss Functions\n",
    "3. A Custom Trainer\n",
    "\n",
    "We also need to write our own `Dataset` for the **facades** dataset. It is fairly simple to modify it to support other datasets, and we highly encourage you to do so in order to get more familiar with the way *torchgan and PyTorch in general work*."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial assumes that your system has **PyTorch** and **TorchGAN** installed properly. If not, the following code block will try to install the **latest tagged version** of TorchGAN. If you need some other version, head over to the installation instructions on the [official documentation website](https://torchgan.readthedocs.io/en/latest/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import torchgan\n",
    "\n",
    "    print(f\"Existing TorchGAN {torchgan.__version__} installation found\")\n",
    "except ImportError:\n",
    "    import subprocess\n",
    "    import sys\n",
    "\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"torchgan\"])\n",
    "    import torchgan\n",
    "\n",
    "    print(f\"Installed TorchGAN {torchgan.__version__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys, glob, random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from PIL import Image\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "import torchgan\n",
    "from torchgan.models import Generator, Discriminator\n",
    "from torchgan.trainer import Trainer\n",
    "from torchgan.losses import (\n",
    "    GeneratorLoss,\n",
    "    DiscriminatorLoss,\n",
    "    least_squares_generator_loss,\n",
    "    least_squares_discriminator_loss,\n",
    ")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set random seed for reproducibility\n",
    "manualSeed = 999\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "print(\"Random Seed: \", manualSeed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LOADING THE DATASET\n",
    "\n",
    "We will try to download the **Facades Data** in the required format. In case the setup fails for any reason, use the following instructions to set up the data (and [open an issue](https://github.com/torchgan/torchgan/issues/new?assignees=&labels=bug&template=bug_report.md&title=)).\n",
    "\n",
    "<br>\n",
    "<details><summary><b>Click to view Data Setup Instructions</b></summary>\n",
    "\n",
    "We are going to work with the **Facades Data**. You should use [this script](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/datasets/download_cyclegan_dataset.sh) to download the data.\n",
    "\n",
    "The downloaded data should be present in the following format:\n",
    "\n",
    "```\n",
    "    ---> facades\n",
    "        ---> trainA\n",
    "            ---> ...\n",
    "            ---> ...\n",
    "                .\n",
    "                .\n",
    "                .\n",
    "        ---> trainB\n",
    "            ---> ...\n",
    "            ---> ...\n",
    "                .\n",
    "                .\n",
    "                .\n",
    "        ---> testA\n",
    "            ---> ...\n",
    "            ---> ...\n",
    "                .\n",
    "                .\n",
    "                .\n",
    "        ---> testB\n",
    "            ---> ...\n",
    "            ---> ...\n",
    "                .\n",
    "                .\n",
    "                .\n",
    "```\n",
    "The `...` signify the image files. For the sake of this tutorial we will only make use of the `train` files, though ideally for inference you should be using the `test` files.\n",
    "\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import os\n",
    "from zipfile import ZipFile\n",
    "\n",
    "print(\"Beginning file download\")\n",
    "\n",
    "url = \"https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/facades.zip\"\n",
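    "r = requests.get(url)\n",
    "# Fail early on an HTTP error instead of writing an error page to disk\n",
    "r.raise_for_status()\n",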
    "\n",
    "with open(\"./facades.zip\", \"wb\") as f:\n",
    "    f.write(r.content)\n",
    "\n",
    "with ZipFile(\"facades.zip\", \"r\") as zipObj:\n",
    "    # Extract all the contents of zip file in different directory\n",
    "    zipObj.extractall(\"datasets\")\n",
    "\n",
    "print(\"Dataset download successful\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the `__getitem__` method of the dataset defined below returns a `dict`. This is a vital detail; the reason for it will be explained when we subclass the `Trainer`. For now, just remember that returning a `list` or `tuple` instead would create issues for our code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageDataset(Dataset):\n",
    "    def __init__(self, root, transform=None, mode=\"train\"):\n",
    "        self.transform = transform\n",
    "        self.files_A = sorted(\n",
    "            glob.glob(os.path.join(root, \"{}A\".format(mode)) + \"/*.*\")\n",
    "        )\n",
    "        self.files_B = sorted(\n",
    "            glob.glob(os.path.join(root, \"{}B\".format(mode)) + \"/*.*\")\n",
    "        )\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))\n",
    "        item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))\n",
    "        return {\"A\": item_A, \"B\": item_B}\n",
    "\n",
    "    def __len__(self):\n",
    "        return max(len(self.files_A), len(self.files_B))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = ImageDataset(\n",
    "    \"./datasets/facades\",\n",
    "    transform=transforms.Compose(\n",
    "        [\n",
    "            transforms.CenterCrop((64, 64)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n",
    "        ]\n",
    "    ),\n",
    ")"
   ]
  },
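  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the preprocessing above: `ToTensor` produces values in $[0, 1]$, and `Normalize` with a mean and standard deviation of $0.5$ then maps each channel to $[-1, 1]$. This is the same range as the `Tanh` output of the generator defined later in this tutorial, so real and generated images live on the same scale:\n",
    "\n",
    "$$x' = \\frac{x - 0.5}{0.5} = 2x - 1, \\qquad x = 0 \\mapsto -1, \\quad x = 1 \\mapsto 1$$"
   ]
  },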
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `batch_size` used below works on an **11GB Nvidia GTX 1080 Ti GPU**. If you are running this notebook on a less powerful machine, please reduce the `batch_size`; otherwise you will most likely encounter a CUDA out-of-memory error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VISUALIZE THE TRAINING DATA\n",
    "\n",
    "Let's display the pictures that will serve as our training dataset.\n",
    "We will simultaneously train 2 generators, **gen_a2b** and **gen_b2a**. **gen_a2b** will learn to translate images of **type A** (the normal images) to images of **type B** (the segmented images). **gen_b2a** will learn to do the reverse translation.\n",
    "\n",
    "**Note**: *Even though we are using paired images in this tutorial, CycleGAN can work with unpaired data. For some datasets you will not have paired images, so feel free to use unpaired data. The only thing to take into account is to add some randomization when selecting the image pair in such a scenario. The easiest way is to replace the index used to select `item_B` in `__getitem__` with `np.random.randint(0, len(self.files_B))`.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 8))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Images A\")\n",
    "plt.imshow(\n",
    "    np.transpose(\n",
    "        vutils.make_grid(\n",
    "            a[\"A\"][:64], padding=2, normalize=True\n",
    "        ).cpu(),\n",
    "        (1, 2, 0),\n",
    "    )\n",
    ")\n",
    "plt.show()\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Images B\")\n",
    "plt.imshow(\n",
    "    np.transpose(\n",
    "        vutils.make_grid(\n",
    "            a[\"B\"][:64], padding=2, normalize=True\n",
    "        ).cpu(),\n",
    "        (1, 2, 0),\n",
    "    )\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DEFINING THE GENERATOR & DISCRIMINATOR\n",
    "\n",
    "First we define the building blocks of the model. TorchGAN provides standard residual blocks; however, we need a specific form of residual block for the CycleGAN model. In the paper [Instance Normalization: The Missing Ingredient for Fast Stylization by Ulyanov et al.](https://arxiv.org/abs/1607.08022), the authors describe the use of Instance Normalization for style transfer. In a similar vein, we use Instance Norm instead of Batch Norm and swap the **Zero Padding** of the convolutional layers with **Reflection Padding**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(nn.Module):\n",
    "    def __init__(self, in_features):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.conv_block = nn.Sequential(\n",
    "            nn.ReflectionPad2d(1),\n",
    "            nn.Conv2d(in_features, in_features, 3),\n",
    "            nn.InstanceNorm2d(in_features),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.ReflectionPad2d(1),\n",
    "            nn.Conv2d(in_features, in_features, 3),\n",
    "            nn.InstanceNorm2d(in_features),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x + self.conv_block(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The **CycleGAN Generator** has 3 parts:\n",
    "\n",
    "1. A downsampling network: It is composed of 3 convolutional layers (together with the regular padding, normalization and activation layers).\n",
    "2. A chain of residual networks built using the Residual Block. You can try to vary the `res_blocks` parameter and see the results.\n",
    "3. An upsampling network: It is composed of 2 transposed convolutional layers followed by a final convolutional layer with a `Tanh` activation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also need to define a `sampler` function, which provides a fixed visual reference for seeing how the generator is performing. The sampler must receive 2 inputs, `sample_size` and `device`, and it should return a list of the arguments needed by the `forward` function of the generator. Here it simply returns the stored `image_batch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGANGenerator(Generator):\n",
    "    def __init__(self, image_batch, in_channels=3, out_channels=3, res_blocks=5):\n",
    "        super(CycleGANGenerator, self).__init__(in_channels)\n",
    "\n",
    "        self.image_batch = image_batch\n",
    "\n",
    "        # Initial convolution block\n",
    "        model = [\n",
    "            nn.ReflectionPad2d(3),\n",
    "            nn.Conv2d(in_channels, 64, 7),\n",
    "            nn.InstanceNorm2d(64),\n",
    "            nn.ReLU(inplace=True),\n",
    "        ]\n",
    "\n",
    "        # Downsampling\n",
    "        in_features = 64\n",
    "        out_features = in_features * 2\n",
    "        for _ in range(2):\n",
    "            model += [\n",
    "                nn.Conv2d(in_features, out_features, 4, stride=2, padding=1),\n",
    "                nn.InstanceNorm2d(out_features),\n",
    "                nn.ReLU(inplace=True),\n",
    "            ]\n",
    "            in_features = out_features\n",
    "            out_features = in_features * 2\n",
    "\n",
    "        # Residual blocks\n",
    "        for _ in range(res_blocks):\n",
    "            model += [ResidualBlock(in_features)]\n",
    "\n",
    "        # Upsampling\n",
    "        out_features = in_features // 2\n",
    "        for _ in range(2):\n",
    "            model += [\n",
    "                nn.ConvTranspose2d(in_features, out_features, 4, stride=2, padding=1),\n",
    "                nn.InstanceNorm2d(out_features),\n",
    "                nn.ReLU(inplace=True),\n",
    "            ]\n",
    "            in_features = out_features\n",
    "            out_features = in_features // 2\n",
    "\n",
    "        # Output layer\n",
    "        model += [nn.ReflectionPad2d(3), nn.Conv2d(64, out_channels, 7), nn.Tanh()]\n",
    "\n",
    "        self.model = nn.Sequential(*model)\n",
    "\n",
    "        self._weight_initializer()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "    def sampler(self, sample_size, device):\n",
    "        return [self.image_batch.to(device)]"
   ]
  },
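  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the generator returns an image of the same size as its input, we can trace a $64 \\times 64$ input (the crop size used in this tutorial) through the layers above. This is only a sketch. It assumes the usual output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with dilation 1 and no output padding, where $k$ is the kernel size, $s$ the stride and $p$ the padding:\n",
    "\n",
    "$$H_{conv} = \\left\\lfloor \\frac{H + 2p - k}{s} \\right\\rfloor + 1, \\qquad H_{transpose} = (H - 1) \\times s - 2p + k$$\n",
    "\n",
    "1. Initial block: reflection padding of 3 gives $64 + 6 = 70$, and the $7 \\times 7$ convolution gives $70 - 7 + 1 = 64$ (channels $3 \\rightarrow 64$).\n",
    "2. Downsampling: $\\lfloor (64 + 2 - 4) / 2 \\rfloor + 1 = 32$, then $\\lfloor (32 + 2 - 4) / 2 \\rfloor + 1 = 16$ (channels $64 \\rightarrow 128 \\rightarrow 256$).\n",
    "3. Residual blocks: reflection padding of 1 gives $18$, and each $3 \\times 3$ convolution gives $18 - 3 + 1 = 16$, so the size and the 256 channels are preserved.\n",
    "4. Upsampling: $(16 - 1) \\times 2 - 2 + 4 = 32$, then $(32 - 1) \\times 2 - 2 + 4 = 64$ (channels $256 \\rightarrow 128 \\rightarrow 64$).\n",
    "5. Output layer: reflection padding of 3 gives $70$, and the $7 \\times 7$ convolution gives $64$ again (channels $64 \\rightarrow 3$)."
   ]
  },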
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The **CycleGAN Discriminator** is like the standard DCGAN Discriminator. The only difference is the normalization used: just like in the Generator, we use Instance Normalization in the Discriminator as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGANDiscriminator(Discriminator):\n",
    "    def __init__(self, in_channels=3):\n",
    "        super(CycleGANDiscriminator, self).__init__(in_channels)\n",
    "\n",
    "        def discriminator_block(in_filters, out_filters, normalize=True):\n",
    "            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]\n",
    "            if normalize:\n",
    "                layers.append(nn.InstanceNorm2d(out_filters))\n",
    "            layers.append(nn.LeakyReLU(0.2, inplace=True))\n",
    "            return layers\n",
    "\n",
    "        self.model = nn.Sequential(\n",
    "            *discriminator_block(in_channels, 64, normalize=False),\n",
    "            *discriminator_block(64, 128),\n",
    "            *discriminator_block(128, 256),\n",
    "            *discriminator_block(256, 512),\n",
    "            nn.ZeroPad2d((1, 0, 1, 0)),\n",
    "            nn.Conv2d(512, 1, 4, padding=1)\n",
    "        )\n",
    "\n",
    "        self._weight_initializer()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)"
   ]
  },
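  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is worth checking what shape this discriminator returns for the $64 \\times 64$ crops used in this tutorial. The following is a sketch using the usual `nn.Conv2d` output-size formula (dilation 1), with kernel size $k$, stride $s$ and padding $p$:\n",
    "\n",
    "$$H_{out} = \\left\\lfloor \\frac{H + 2p - k}{s} \\right\\rfloor + 1$$\n",
    "\n",
    "1. Each `discriminator_block` uses $k = 4$, $s = 2$, $p = 1$ and therefore halves the size: $64 \\rightarrow 32 \\rightarrow 16 \\rightarrow 8 \\rightarrow 4$.\n",
    "2. `nn.ZeroPad2d((1, 0, 1, 0))` adds one row at the top and one column on the left: $4 \\rightarrow 5$.\n",
    "3. The final convolution uses $k = 4$, $s = 1$, $p = 1$: $5 + 2 - 4 + 1 = 4$.\n",
    "\n",
    "So for each image the discriminator produces a $1 \\times 4 \\times 4$ grid of scores rather than a single scalar."
   ]
  },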
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LOSS FUNCTIONS\n",
    "\n",
    "The Generator Loss is composed of 3 parts. They are described below:\n",
    "\n",
    "1. **GAN Loss**: It is the standard generator loss of the Least Squares GAN. We use the functional forms of the losses to implement this part.\n",
    "$$L_{GAN} = \\frac{1}{4} \\times ((D_A(G_{B2A}(Image_B)) - 1)^2 + (D_B(G_{A2B}(Image_A)) - 1)^2)$$\n",
    "2. **Identity Loss**: It computes the similarity of a real image of type B and a fake image B generated from image A and vice versa. The similarity is measured using the $L_1$ Loss.\n",
    "$$L_{identity} = \\frac{1}{2} \\times (||G_{B2A}(Image_B) - Image_A||_1 + ||G_{A2B}(Image_A) - Image_B||_1)$$\n",
    "3. **Cycle Consistency Loss**: This loss computes the similarity of the original image and the image generated by a composition of the 2 generators. This allows CycleGAN to deal with unpaired images. We reconstruct the original image and try to minimize the $L_1$ norm between the original image and this reconstructed image.\n",
    "$$L_{cycle\\_consistency} = \\frac{1}{2} \\times (||G_{B2A}(G_{A2B}(Image_A)) - Image_A||_1 + ||G_{A2B}(G_{B2A}(Image_B)) - Image_B||_1)$$\n",
    "\n",
    "The losses can be decomposed into 3 different loss functions; however, doing that would not be in our best interest. In that case we would be backpropagating 3 times through the networks, which would significantly slow down your code. So the general rule in torchgan is to combine losses into a single loss object when doing so improves the performance of your code; otherwise keep them separate (this leads to better loss visualization) and feed them in through the losses list.\n",
    "\n",
    "Now let us see the naming convention for the `train_ops` arguments. We simply list all the variables stored in the Trainer that we need. We are guaranteed to get all these variables if they are present in the Trainer. If one of them is missing, you will receive an error message stating which argument was not found. You can then define that argument or use `set_arg_map` to fix that. The details of this method are described in the documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGANGeneratorLoss(GeneratorLoss):\n",
    "    def train_ops(\n",
    "        self,\n",
    "        gen_a2b,\n",
    "        gen_b2a,\n",
    "        dis_a,\n",
    "        dis_b,\n",
    "        optimizer_gen_a2b,\n",
    "        optimizer_gen_b2a,\n",
    "        image_a,\n",
    "        image_b,\n",
    "    ):\n",
    "        optimizer_gen_a2b.zero_grad()\n",
    "        optimizer_gen_b2a.zero_grad()\n",
    "        fake_a = gen_b2a(image_b)\n",
    "        fake_b = gen_a2b(image_a)\n",
    "        loss_identity = 0.5 * (F.l1_loss(fake_a, image_a) + F.l1_loss(fake_b, image_b))\n",
    "        loss_gan = 0.5 * (\n",
    "            least_squares_generator_loss(dis_a(fake_a))\n",
    "            + least_squares_generator_loss(dis_b(fake_b))\n",
    "        )\n",
    "        loss_cycle_consistency = 0.5 * (\n",
    "            F.l1_loss(gen_a2b(fake_a), image_b) + F.l1_loss(gen_b2a(fake_b), image_a)\n",
    "        )\n",
    "        loss = loss_identity + loss_gan + loss_cycle_consistency\n",
    "        loss.backward()\n",
    "        optimizer_gen_a2b.step()\n",
    "        optimizer_gen_b2a.step()\n",
    "        return loss.item()"
   ]
  },
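  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the three terms together, the `train_ops` above minimizes their unweighted sum with a single backward pass through both generators:\n",
    "\n",
    "$$L_{generator} = L_{GAN} + L_{identity} + L_{cycle\\_consistency}$$\n",
    "\n",
    "The original paper weights the cycle consistency term by a factor $\\lambda = 10$. Adding such a weight in front of $L_{cycle\\_consistency}$ is a natural variation to experiment with, but it is not part of the code in this tutorial."
   ]
  },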
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Discriminator, as mentioned before, is similar to the standard DCGAN Discriminator, and its loss function is the standard discriminator loss of the Least Squares GAN. Again we list all the required variables in `train_ops` and we are guaranteed to get those from the Trainer.\n",
    "\n",
    "$$L_{GAN} = \\frac{1}{4} \\times (((D_A(Image_A) - 1)^2 + D_A(G_{B2A}(Image_B))^2) + ((D_B(Image_B) - 1)^2 + D_B(G_{A2B}(Image_A))^2))$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGANDiscriminatorLoss(DiscriminatorLoss):\n",
    "    def train_ops(\n",
    "        self,\n",
    "        gen_a2b,\n",
    "        gen_b2a,\n",
    "        dis_a,\n",
    "        dis_b,\n",
    "        optimizer_dis_a,\n",
    "        optimizer_dis_b,\n",
    "        image_a,\n",
    "        image_b,\n",
    "    ):\n",
    "        optimizer_dis_a.zero_grad()\n",
    "        optimizer_dis_b.zero_grad()\n",
    "        fake_a = gen_b2a(image_b).detach()\n",
    "        fake_b = gen_a2b(image_a).detach()\n",
    "        loss = 0.5 * (\n",
    "            least_squares_discriminator_loss(dis_a(image_a), dis_a(fake_a))\n",
    "            + least_squares_discriminator_loss(dis_b(image_b), dis_b(fake_b))\n",
    "        )\n",
    "        loss.backward()\n",
    "        optimizer_dis_a.step()\n",
    "        optimizer_dis_b.step()\n",
    "        return loss.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DEFINING THE CUSTOM TRAINER FOR CYCLEGAN\n",
    "\n",
    "Even though the Trainer has been designed to be as general as possible, it cannot handle arbitrary input data formats. However, the current design provides a neat trick to bypass this shortcoming. The data is handled in 3 different ways:\n",
    "\n",
    "1. If it is a list or tuple, then *real_inputs* stores the first element and *labels* stores the second element. Since we expect these to be tensors, we push them to the device. This might be troublesome in cases where the input data from the data loader is not a tensor, but even that can be handled.\n",
    "2. If it is a torch Tensor, we simply save it in *real_inputs* and push it to the device.\n",
    "3. Now the third and most interesting case: if neither of the above applies, we store the data in *real_inputs* and leave it completely untouched, so you can do anything you need with it. Hence we recommend that if you have custom data, you use a dictionary to fetch it.\n",
    "\n",
    "Now let's come to the trick mentioned before. Since we defined our dataset to return a `dict`, we are guaranteed to have the untouched data in a variable named `real_inputs`. So we simply redefine the `train_iter_custom` function, which is called every time before `train_iter`. In this function we unpack the data into 2 variables, `image_a` and `image_b`, which is exactly what `train_ops` needs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGANTrainer(Trainer):\n",
    "    def train_iter_custom(self):\n",
    "        self.image_a = self.real_inputs[\"A\"].to(self.device)\n",
    "        self.image_b = self.real_inputs[\"B\"].to(self.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*NB: To generate realistic samples increase the epochs to **100**.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "    # Use deterministic cudnn algorithms\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    epochs = 10\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    epochs = 5\n",
    "\n",
    "print(\"Device: {}\".format(device))\n",
    "print(\"Epochs: {}\".format(epochs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`image_batch` will act as a reference and we can visualize its transformation over the course of the model training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_batch = next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "network_config = {\n",
    "    \"gen_a2b\": {\n",
    "        \"name\": CycleGANGenerator,\n",
    "        \"args\": {\"image_batch\": image_batch[\"A\"]},\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "    \"gen_b2a\": {\n",
    "        \"name\": CycleGANGenerator,\n",
    "        \"args\": {\"image_batch\": image_batch[\"B\"]},\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "    \"dis_a\": {\n",
    "        \"name\": CycleGANDiscriminator,\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "    \"dis_b\": {\n",
    "        \"name\": CycleGANDiscriminator,\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses = [CycleGANGeneratorLoss(), CycleGANDiscriminatorLoss()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another important detail is the last 2 arguments. At the time of instantiation, the trainer checks whether all the variables required by the `train_ops` are present, so we need to make sure that the object has attributes named `image_a` and `image_b`. The trainer stores any keyword argument that it receives, hence this is the simplest way to prevent that error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = CycleGANTrainer(\n",
    "    network_config, losses, device=device, epochs=epochs, image_a=None, image_b=None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer(dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VISUALIZING THE GENERATED DATA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(16, 16))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Generated Images A\")\n",
    "plt.imshow(plt.imread(\"{}/epoch{}_gen_b2a.png\".format(trainer.recon, trainer.epochs)))\n",
    "plt.show()\n",
    "plt.figure(figsize=(16, 16))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Generated Images B\")\n",
    "plt.imshow(plt.imread(\"{}/epoch{}_gen_a2b.png\".format(trainer.recon, trainer.epochs)))\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
